from collections.abc import Iterable
from math import isnan
from typing import Literal

import numpy as np
from einops import repeat
from jaxtyping import Float, Int


def round_values(
    values: Float[np.ndarray, "row col"],
    precisions: Int[np.ndarray, " col"],
) -> Float[np.ndarray, "row col"]:
    """Round each column of ``values`` to that column's number of decimals.

    Args:
        values: Table to round; it is only read, never modified.
        precisions: Decimal places for each column, passed to ``np.round``.

    Returns:
        A new array shaped like ``values`` in which column ``j`` is rounded to
        ``precisions[j]`` decimals.
    """
    num_rows, _ = values.shape
    # Tile the per-column precisions so every element has its own precision.
    element_precisions = repeat(precisions, "c -> r c", r=num_rows)

    rounded = np.zeros_like(values)
    # Round all elements that share a precision with a single np.round call.
    for decimals in np.unique(element_precisions):
        selection = element_precisions == decimals
        rounded[selection] = np.round(values[selection], decimals)
    return rounded


def compute_ranks_for_column(
    values: Float[np.ndarray, " row"],
    order: Literal[-1, 0, 1],
) -> Int[np.ndarray, " row"]:
    """Dense-rank the entries of one column, best value first.

    Equal values share a rank and ranks are consecutive (0, 1, 2, ...): rank 0
    is the best distinct value, rank 1 the second-best, and so on.

    Args:
        values: The column to rank. NaN marks a missing entry.
        order: ``1`` if larger values are better, ``-1`` if smaller values are
            better, ``0`` if the column should not be ranked at all.

    Returns:
        An ``int32`` array of ranks with the same shape as ``values``. Entries
        that are not ranked receive a large sentinel rank (100000), so they
        never land among the top positions. This covers every entry when
        ``order == 0`` and NaN entries otherwise.
    """
    unranked = 100_000  # Sentinel rank, far beyond any highlighted position.

    # If the entries are unordered, nothing is ranked.
    if order == 0:
        return np.full_like(values, unranked, dtype=np.int32)

    # Missing (NaN) entries keep the sentinel rank rather than being ranked
    # last. Otherwise, with three or fewer real values, a missing entry would
    # still fall within the top ranks and be highlighted.
    ranks = np.full(values.shape, unranked, dtype=np.int32)
    present_mask = ~np.isnan(values)
    present = values[present_mask]

    # np.unique returns the distinct values sorted ascending, so each entry's
    # index in that array is its dense rank when smaller is better.
    unique_values = np.unique(present)
    positions = np.searchsorted(unique_values, present)
    if order == 1:
        # Larger is better: count positions from the top instead.
        positions = len(unique_values) - 1 - positions
    ranks[present_mask] = positions
    return ranks


def compute_ranks(
    values: Float[np.ndarray, "row col"],
    orders: Int[np.ndarray, " col"],
) -> Int[np.ndarray, "row col"]:
    """Rank every column of ``values`` independently.

    Args:
        values: Table of results, one column per metric.
        orders: Per-column ranking direction, forwarded as a Python scalar to
            ``compute_ranks_for_column``.

    Returns:
        An ``int64`` array shaped like ``values`` holding each column's ranks.
    """
    _, num_cols = values.shape
    result = np.zeros_like(values, dtype=np.int64)
    for index in range(num_cols):
        column_order = orders[index].item()
        result[:, index] = compute_ranks_for_column(values[:, index], column_order)
    return result


TableRows = dict[str, list[float | None]]  # Maps method names to rows of results.
MultiHeaders = Iterable[tuple[str, int]]  # Each element is (text, num columns).


def make_latex_table(
    results: TableRows,
    metrics: list[str],
    precisions: list[int],
    rank_orders: list[Literal[-1, 0, 1]],
    none_str: str = "N/A",
    multi_headers: MultiHeaders | None = None,
) -> str:
    r"""Render method-vs-metric results as a booktabs-style LaTeX ``tabular``.

    Each row is one method (a key of ``results``) and each column one metric.
    Values are rounded to their column's precision and ranked per column. The
    three best distinct values are wrapped in ``\first``, ``\second`` and
    ``\third``; these are presumably macros defined by the including document,
    since they are not standard LaTeX. Missing values (``None`` or NaN) are
    printed as ``none_str`` and are never highlighted.

    Args:
        results: Maps each method name to its row of metric values, one per
            metric; ``None`` marks a missing value.
        metrics: Header text for each metric column.
        precisions: Decimal places per metric, used both for ranking and for
            printing.
        rank_orders: Per metric, ``1`` if higher is better, ``-1`` if lower is
            better, ``0`` if the column is not ranked. This also selects the
            arrow appended to the metric's header.
        none_str: Text printed for missing values.
        multi_headers: Optional ``(text, span)`` groups rendered as an extra
            header row above the metric names. Any iterable is accepted; the
            spans must sum to the number of metric columns.

    Returns:
        The table source, between ``%%% BEGIN AUTOGENERATED %%%`` and
        ``%%% END AUTOGENERATED %%%`` marker lines.

    Raises:
        ValueError: If the ``multi_headers`` spans do not sum to the number of
            metric columns.
    """
    # Materialize the groups once. They are traversed twice below (header cells
    # and column specification), which a one-shot iterable would not survive.
    header_groups = None if multi_headers is None else list(multi_headers)

    # With a float64 dtype, None entries become NaN.
    data = np.array(list(results.values()), dtype=np.float64)

    # Rank the rounded values so that entries printed identically also tie.
    data_rounded = round_values(data, np.array(precisions))
    ranks = compute_ranks(data_rounded, np.array(rank_orders))

    rank_functions = (
        lambda x: f"\\first{{{x}}}",
        lambda x: f"\\second{{{x}}}",
        lambda x: f"\\third{{{x}}}",
        lambda x: x,
    )
    rank_symbols = {
        0: "",
        1: " $\\uparrow$",
        -1: " $\\downarrow$",
    }

    # Add arrows to the metric names showing which direction is better.
    metrics = [
        f"{metric}{rank_symbols[rank_order]}"
        for metric, rank_order in zip(metrics, rank_orders)
    ]

    if header_groups is not None:
        total_span = sum(span for _, span in header_groups)
        if total_span != len(metrics):
            raise ValueError(
                f"multi_headers spans sum to {total_span}, "
                f"but there are {len(metrics)} metric columns."
            )

    def format_cell(row: int, col: int) -> str:
        """Format one data cell, wrapped according to its rank."""
        value = data_rounded[row, col]
        if isnan(value):
            # Missing values are never highlighted, whatever rank they got.
            return none_str
        # Print the rounded value that was actually ranked. np.round scales by a
        # power of ten before rounding, so near halfway points it can disagree
        # with formatting the raw value and contradict the highlighting.
        text = f"{value:.{precisions[col]}f}"
        return rank_functions[min(ranks[row, col], len(rank_functions) - 1)](text)

    # Header row followed by one row of formatted cells per method.
    _, num_cols = data_rounded.shape
    cells = [
        ["Method", *metrics],
        *(
            [method_name, *(format_cell(row, col) for col in range(num_cols))]
            for row, method_name in enumerate(results)
        ),
    ]

    # Pad every column to its widest cell. The header row and the method names
    # are left-aligned; numeric cells are right-aligned.
    widths = np.array([[len(cell) for cell in row] for row in cells])
    max_widths = np.max(widths, axis=0)
    cells = [
        [
            (cell.rjust if row > 0 and col > 0 else cell.ljust)(max_widths[col])
            for col, cell in enumerate(row_cells)
        ]
        for row, row_cells in enumerate(cells)
    ]

    # Join the cells into LaTeX rows.
    header_row, *other_rows = [" & ".join(row) + " \\\\" for row in cells]

    # Create the optional grouped header row and the matching column layout.
    if header_groups is None:
        multi_header_rows = []
        column_specifications = "r" * len(metrics)
    else:
        last = len(header_groups) - 1
        # NOTE(review): the Method cell ends with "|" and every group starts with
        # "|", which may draw doubled rules at group boundaries. Confirm this
        # is intended.
        columns = [
            f"\\multicolumn{{{span}}}{{{'|c|' if i < last else '|c'}}}"
            f"{{{text}}}"
            for i, (text, span) in enumerate(header_groups)
        ]
        multi_header_rows = [
            " & ".join(("\\multicolumn{1}{c|}{}", *columns)) + " \\\\",
            "\\midrule",
        ]
        # One "|"-prefixed run of right-aligned columns per group.
        column_specifications = "".join(
            "|" + "r" * span for _, span in header_groups
        )

    rows = [
        "%%% BEGIN AUTOGENERATED %%%",
        f"\\begin{{tabular}}{{l{column_specifications}}}",
        "\\toprule",
        *multi_header_rows,
        header_row,
        "\\midrule",
        *other_rows,
        "\\bottomrule",
        "\\end{tabular}",
        "%%% END AUTOGENERATED %%%",
    ]
    return "\n".join(rows)
